
from arguments import get_args
import os
from Record.file_management import read_obj_dumps, load_from_pickle, save_to_pickle, create_directory

def create_ROC_curve(args, model_path, *, round_scores=True):
    """Compute an ROC curve (false/true positive rates) for one saved model.

    Builds the environment, extractor and ``InferenceModel`` described by
    ``args``, loads the weights at ``model_path``, evaluates the model on the
    test buffer with ``test_dataset`` and scores the model's ``inter_masks``
    against the ``utrace`` labels. Only binaries whose label is neither
    (almost) always 0 nor (almost) always 1 on the test set are scored.

    Args:
        args: parsed configuration as returned by ``get_args()``.
        model_path: path of the saved model passed to ``load_model``.
        round_scores: when True (the historical behaviour) the mask values are
            rounded to 0/1 before ``roc_curve``. NOTE: rounding collapses the
            curve to a single operating point; pass False to score with the
            unrounded mask values.

    Returns:
        ``(fpr, tpr)`` arrays as returned by ``sklearn.metrics.roc_curve``.

    Raises:
        ValueError: if no binary has a non-constant label on the test set.
    """
    import numpy as np
    import torch
    from sklearn.metrics import roc_curve

    from ActualCausal.train_loop import test_dataset
    from ActualCausal.Utils.weighting import separate_weights, get_weights
    from ACState.compute_mean_variance import compute_encoder_mean_variance
    from ACState.extractor import regenerate
    from Buffer.train_test_buffers import generate_buffers
    from Environment.Environments.initialize_environment import initialize_environment
    from Model.base_model import InferenceModel
    from Model.model_utils import load_model
    from Network.network_utils import pytorch_model

    np.set_printoptions(threshold=3000, linewidth=120, precision=4, suppress=True)
    torch.set_printoptions(precision=4, sci_mode=False)
    if args.torch.cuda:
        # Only touch the CUDA runtime when CUDA is requested: set_device
        # raises on machines without a usable GPU.
        torch.cuda.set_device(args.torch.gpu)
        device = "cuda:" + str(args.torch.gpu)
    else:
        device = "cpu"

    environment, _ = initialize_environment(args.environment, args.record)
    # initialize the state handling, and fills args.factor with appropriate values
    enc_range, enc_dyn = compute_encoder_mean_variance(args)
    encoding_dim = args.image_enc.encoding_dim if len(args.train.load_encodings) > 0 else -1
    extractor, normalization = regenerate(args, environment, all=True, encoding_dim=encoding_dim, enc_rng=enc_range, enc_dyn=enc_dyn)
    model = InferenceModel(args, extractor, normalization, environment)
    model = load_model(model, model_path, device=device)

    # Train/test buffers: reload a cached split if configured, else generate one.
    buffer_name = environment.name + "_traintest.pkl"
    if len(args.record.load_intermediate) > 0:
        train_buffer, test_buffer = load_from_pickle(os.path.join(args.record.load_intermediate, buffer_name))
    else:
        train_buffer, test_buffer = generate_buffers(environment, args, extractor, normalization)
    if len(args.record.save_intermediate) > 0:
        save_to_pickle(os.path.join(create_directory(args.record.save_intermediate), buffer_name), (train_buffer, test_buffer))

    _, _, binaries = separate_weights(args, args.inter.weighting_type, model, test_buffer)
    weights = get_weights(args.active.weighting[0], binaries)

    # TODO: assumes only one key for train names, and only one inference for infer_names
    result = test_dataset(args, model, test_buffer, extractor, weights=weights)[args.inter.train_names[0]][args.infer.infer_types[0]]
    utrace = result.utrace[:, 0]

    # predict only those binaries with a meaningful (non-constant) utrace
    useful_binaries = _informative_columns(utrace)
    if len(useful_binaries) == 0:
        raise ValueError("no binary in utrace has a non-constant label on the test set; cannot compute an ROC curve")

    scores, labels = [], []
    for b in useful_binaries:
        score = pytorch_model.unwrap(result.inter_masks[:, 0, b])
        scores.append(np.round(score) if round_scores else score)
        labels.append(utrace[:, b])
    print(scores[0].shape, labels[0].shape)

    fpr, tpr, _ = roc_curve(np.concatenate(labels, axis=0), np.concatenate(scores, axis=0))
    return fpr, tpr


def _informative_columns(labels, low=0.001, high=0.999):
    """Return indices of the columns of ``labels`` whose mean lies in [low, high].

    Columns that are essentially always 0 or always 1 carry no ROC information
    and are skipped.
    """
    import numpy as np

    rate = np.mean(labels, axis=0)
    return np.nonzero((rate >= low) & (rate <= high))[0]

def plot_roc(fprs, tprs, labels, target):
    """Plot one ROC curve per entry on a shared axis and save it as a PDF.

    Args:
        fprs: sequence of false-positive-rate arrays, one per curve.
        tprs: sequence of true-positive-rate arrays, aligned with ``fprs``.
        labels: legend label per curve; also joined with "_" to name the file.
        target: directory the PDF is written into (created if missing).

    Returns:
        Path of the written PDF.

    Raises:
        ValueError: if ``fprs``, ``tprs`` and ``labels`` differ in length
            (``zip`` would otherwise silently drop curves).
    """
    import matplotlib.pyplot as plt

    if not (len(fprs) == len(tprs) == len(labels)):
        raise ValueError(
            f"fprs, tprs and labels must have equal length, got {len(fprs)}, {len(tprs)}, {len(labels)}"
        )

    # Use a dedicated figure so repeated calls do not draw on top of each
    # other through pyplot's implicit "current figure".
    fig, ax = plt.subplots()
    try:
        for fpr, tpr, label in zip(fprs, tprs, labels):
            ax.plot(fpr, tpr, marker='.', label=label)
        ax.set_xlabel('False Positive Rate')
        ax.set_ylabel('True Positive Rate')
        ax.legend()
        if target:
            os.makedirs(target, exist_ok=True)
        path = os.path.join(target, "_".join(labels) + ".pdf")
        fig.savefig(path)
    finally:
        # Release the figure; pyplot otherwise keeps it alive indefinitely.
        plt.close(fig)
    return path

if __name__ == "__main__":
    args = get_args()
    print(args)  # print out args for records

    # Make sure the output directory exists before anything is written into it
    # (same create_directory pattern as used for save_intermediate).
    save_dir = create_directory(args.record.save_dir) if len(args.record.save_dir) > 0 else args.record.save_dir

    if args.ROC.generate_plot:
        # Re-plot previously computed curves instead of re-evaluating models.
        fprs, tprs = load_from_pickle(os.path.join(args.record.load_dir, "fprtpr.pkl"))
        plot_roc(fprs, tprs, args.ROC.labels, save_dir)
    else:
        curves = [create_ROC_curve(args, model_path) for model_path in args.ROC.model_paths]
        fprs = [fpr for fpr, _ in curves]
        tprs = [tpr for _, tpr in curves]
        save_to_pickle(os.path.join(save_dir, "fprtpr.pkl"), (fprs, tprs))


